{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-17T00:04:18.609576Z",
     "start_time": "2020-12-17T00:04:17.663538Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../../')\n",
    "import numpy as np\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "from PIL import Image, ImageOps\n",
    "from scipy.spatial.transform import Rotation\n",
    "import pandas as pd\n",
    "from scipy.spatial import distance\n",
    "import time\n",
    "import os\n",
    "import math\n",
    "import scipy.io as sio\n",
    "from utils.renderer import Renderer\n",
    "from utils.image_operations import expand_bbox_rectangle\n",
    "from utils.pose_operations import get_pose\n",
    "from img2pose import img2poseModel\n",
    "from model_loader import load_model\n",
    "\n",
    "np.set_printoptions(suppress=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset annotations "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-17T00:04:18.618288Z",
     "start_time": "2020-12-17T00:04:18.611543Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dataset_path = \"AFLW2000_annotations.txt\"\n",
    "test_dataset = pd.read_csv(dataset_path, delimiter=\" \", header=None)\n",
    "test_dataset = np.asarray(test_dataset).squeeze()  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Declare useful functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-17T00:04:18.632686Z",
     "start_time": "2020-12-17T00:04:18.620244Z"
    }
   },
   "outputs": [],
   "source": [
    "def bb_intersection_over_union(boxA, boxB):\n",
    "    xA = max(boxA[0], boxB[0])\n",
    "    yA = max(boxA[1], boxB[1])\n",
    "    xB = min(boxA[2], boxB[2])\n",
    "    yB = min(boxA[3], boxB[3])\n",
    "\n",
    "    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)\n",
    "    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)\n",
    "    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)\n",
    "    iou = interArea / float(boxAArea + boxBArea - interArea)\n",
    "    return iou\n",
    "\n",
    "def render_plot(img, pose_pred):\n",
    "    (w, h) = img.size\n",
    "    image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n",
    "\n",
    "    trans_vertices = renderer.transform_vertices(img, [pose_pred])\n",
    "    img = renderer.render(img, trans_vertices, alpha=1)    \n",
    "\n",
    "    plt.figure(figsize=(8, 8))        \n",
    "\n",
    "    plt.imshow(img)        \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the renderer for visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-17T00:04:18.654820Z",
     "start_time": "2020-12-17T00:04:18.634517Z"
    }
   },
   "outputs": [],
   "source": [
    "renderer = Renderer(\n",
    "    vertices_path=\"../../pose_references/vertices_trans.npy\", \n",
    "    triangles_path=\"../../pose_references/triangles.npy\"\n",
    ")\n",
    "\n",
    "threed_points = np.load('../../pose_references/reference_3d_68_points_trans.npy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model weights and pose mean and std deviation\n",
    "To test other models, change MODEL_PATH along the the POSE_MEAN and POSE_STDDEV used for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-17T00:04:21.564127Z",
     "start_time": "2020-12-17T00:04:18.656885Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model will use 1 GPUs!\n"
     ]
    }
   ],
   "source": [
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "DEPTH = 18\n",
    "MAX_SIZE = 1400\n",
    "MIN_SIZE = 600\n",
    "\n",
    "POSE_MEAN = \"../../models/WIDER_train_pose_mean_v1.npy\"\n",
    "POSE_STDDEV = \"../../models/WIDER_train_pose_stddev_v1.npy\"\n",
    "MODEL_PATH = \"../../models/img2pose_v1.pth\"\n",
    "\n",
    "pose_mean = np.load(POSE_MEAN)\n",
    "pose_stddev = np.load(POSE_STDDEV)\n",
    "\n",
    "img2pose_model = img2poseModel(\n",
    "    DEPTH, MIN_SIZE, MAX_SIZE, \n",
    "    pose_mean=pose_mean, pose_stddev=pose_stddev,\n",
    "    threed_68_points=threed_points,\n",
    "    rpn_pre_nms_top_n_test=500,\n",
    "    rpn_post_nms_top_n_test=10,\n",
    ")\n",
    "load_model(img2pose_model.fpn_model, MODEL_PATH, cpu_mode=str(img2pose_model.device) == \"cpu\", model_only=True)\n",
    "img2pose_model.evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run AFLW2000-3D evaluation\n",
    "To visualize the predictions, change visualize to True and change total_imgs to the amount of images desired."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-17T00:06:14.920520Z",
     "start_time": "2020-12-17T00:05:12.688306Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fe06ec4003be46b1a27870486f732d3b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Model failed on 0 images\n",
      "Yaw: 4.417 Pitch: 5.569 Roll: 2.822 MAE: 4.269\n",
      "H. Trans.: 0.032 V. Trans.: 0.046 Scale: 0.268 MAE: 0.115\n",
      "Average time 0.026328349609917465\n"
     ]
    }
   ],
   "source": [
    "visualize = False\n",
    "total_imgs = len(test_dataset)\n",
    "\n",
    "threshold = 0.0\n",
    "targets = []\n",
    "predictions = []\n",
    "\n",
    "total_failures = 0\n",
    "times = []\n",
    "\n",
    "for img_path in tqdm(test_dataset[:total_imgs]):\n",
    "    img = Image.open(img_path).convert(\"RGB\")\n",
    "    \n",
    "    image_name = os.path.split(img_path)[1]\n",
    "    \n",
    "    ori_img = img.copy()\n",
    "    \n",
    "    (w, h) = ori_img.size\n",
    "    image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n",
    "    \n",
    "    mat_contents = sio.loadmat(img_path[:-4] + \".mat\")\n",
    "    target_points = np.asarray(mat_contents['pt3d_68']).T[:, :2]\n",
    "    \n",
    "    _, pose_target = get_pose(threed_points, target_points, image_intrinsics)\n",
    "    \n",
    "    target_bbox = expand_bbox_rectangle(w, h, 1.1, 1.1, target_points, roll=pose_target[2])\n",
    "    \n",
    "    pose_para = np.asarray(mat_contents['Pose_Para'])[0][:3]\n",
    "    pose_target_degrees = pose_para[:3] * (180 / math.pi)\n",
    "    \n",
    "    if np.any(np.abs(pose_target_degrees) > 99):\n",
    "        continue        \n",
    "        \n",
    "    poses = []\n",
    "    run_time = 0\n",
    "    for size in (400,):\n",
    "        img2pose_model.fpn_model.module.set_max_min_size(size, size)            \n",
    "        time1 = time.time()\n",
    "        res = img2pose_model.predict([transform(img)])\n",
    "        time2 = time.time()\n",
    "        run_time += (time2 - time1)\n",
    "\n",
    "        res = res[0]\n",
    "\n",
    "        bboxes = res[\"boxes\"].cpu().numpy().astype('float')\n",
    "        max_iou = 0\n",
    "        best_index = -1\n",
    "\n",
    "        for i in range(len(bboxes)):\n",
    "            if res[\"scores\"][i] > threshold:\n",
    "                bbox = bboxes[i]\n",
    "                pose_pred = res[\"dofs\"].cpu().numpy()[i].astype('float')\n",
    "                pose_pred = np.asarray(pose_pred.squeeze())   \n",
    "                iou = bb_intersection_over_union(bbox, target_bbox)\n",
    "\n",
    "                if iou > max_iou:\n",
    "                    max_iou = iou\n",
    "                    best_index = i    \n",
    "                    \n",
    "        if best_index >= 0:\n",
    "            bbox = bboxes[best_index]\n",
    "            pose_pred = res[\"dofs\"].cpu().numpy()[best_index].astype('float')\n",
    "            pose_pred = np.asarray(pose_pred.squeeze())\n",
    "\n",
    "            poses.append(pose_pred)\n",
    "            \n",
    "        if visualize and best_index >= 0:    \n",
    "            render_plot(ori_img.copy(), pose_pred)\n",
    "            \n",
    "    if len(poses) == 0:\n",
    "        total_failures += 1\n",
    "\n",
    "        continue\n",
    "        \n",
    "    poses = np.asarray(poses)\n",
    "    pose_pred = np.mean(poses, axis=0)\n",
    "    \n",
    "    times.append(run_time)\n",
    "    \n",
    "    targets.append(pose_target)\n",
    "    predictions.append(pose_pred)\n",
    "    \n",
    "pose_mae = np.mean(abs(np.asarray(predictions) - np.asarray(targets)), axis=0)\n",
    "threed_pose = pose_mae[:3] * (180 / math.pi)\n",
    "trans_pose = pose_mae[3:]\n",
    "\n",
    "print(f\"Model failed on {total_failures} images\")\n",
    "\n",
    "print(f\"Yaw: {threed_pose[1]:.3f} Pitch: {threed_pose[0]:.3f} Roll: {threed_pose[2]:.3f} MAE: {np.mean(threed_pose):.3f}\")\n",
    "print(f\"H. Trans.: {trans_pose[0]:.3f} V. Trans.: {trans_pose[1]:.3f} Scale: {trans_pose[2]:.3f} MAE: {np.mean(trans_pose):.3f}\")\n",
    "print(f\"Average time {np.mean(np.asarray(times))}\")"
   ]
  }
 ],
 "metadata": {
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "disseminate_notebook_id": {
   "notebook_id": "336305087387084"
  },
  "disseminate_notebook_info": {
   "bento_version": "20200629-000305",
   "description": "Visualization of smaller model pose and expression qualitative results (trial 4).\nResNet-18 with sum of squared errors weighted equally for both pose and expression.\nFC layers for both pose and expression are fc1 512x512 and fc2 512 x output (output is either 6 or 72).\n",
   "hide_code": false,
   "hipster_group": "",
   "kernel_build_info": {
    "error": "The file located at '/data/users/valbiero/fbsource/fbcode/bento/kernels/local/deep_3d_face_modeling/TARGETS' could not be found."
   },
   "no_uii": true,
   "notebook_number": "296232",
   "others_can_edit": false,
   "reviewers": "",
   "revision_id": "1126967097689153",
   "tags": "",
   "tasks": "",
   "title": "Updated Model Pose and Expression Qualitative Results"
  },
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  },
  "notify_time": "30"
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
